# -*- coding: utf-8 -*-
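"""
PhyCRNet model definition, a polynomial-decay learning-rate scheduler, and
masked / weighted MSE and L1 loss modules.

Created on Tue Jan  2 23:38:57 2018
"""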

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from layers import encoder_block,convlstm_layer
import math
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler,CyclicLR


class PhyCRNet(nn.Module):
    """Sequential stack of ``convlstm_layer`` blocks plus its optimiser setup.

    One block is built for every third index ``start`` in ``range(num_layers)``.
    Block ``start`` receives ``hidden_channels[start]``,
    ``hidden_channels[start + 1]`` and ``hidden_channels[start + 2]``. Its input
    width is ``input_channels`` for the first block, and
    ``hidden_channels[start - 1]`` (the last width handed to the previous block)
    for every later block.

    Args:
        input_channels: channel count of the network input.
        hidden_channels: list of hidden widths, consumed three per block.
        output_channels: accepted but not used by this class.
        input_kernel_size: forwarded to every ``convlstm_layer``.
        input_stride: forwarded to every ``convlstm_layer``.
        padding_mode: forwarded to every ``convlstm_layer`` (e.g. 'reflect').
        num_layers: how many ``hidden_channels`` entries to walk; one block per
            three entries.
        dropout: forwarded to every ``convlstm_layer``.
        warmup_updates, tot_updates, peak_lr, end_lr, power: schedule settings
            used by ``configure_optimizers``.
        weight_decay: AdamW weight decay used by ``configure_optimizers``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels,
                 input_kernel_size, input_stride, padding_mode, num_layers, dropout,
                 warmup_updates, tot_updates, peak_lr, end_lr, power, weight_decay):
        super().__init__()

        # Optimisation hyper-parameters, consumed by configure_optimizers().
        self.warmup_updates = warmup_updates
        self.tot_updates = tot_updates
        self.peak_lr = peak_lr
        self.end_lr = end_lr
        self.power = power
        self.weight_decay = weight_decay
        self.dropout = dropout

        # input_channels[k] is the width entering position k of the stack:
        # the raw input for k == 0, hidden_channels[k - 1] otherwise.
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.input_kernel_size = input_kernel_size
        self.input_stride = input_stride

        self.num_convlstm = num_layers
        self.convlstm = nn.ModuleList([
            convlstm_layer(
                self.input_channels[start],
                self.hidden_channels[start],
                self.hidden_channels[start + 1],
                self.hidden_channels[start + 2],
                input_kernel_size,
                input_stride,
                padding_mode,
                self.dropout,
            )
            for start in range(0, self.num_convlstm, 3)
        ])

    def forward(self, x):
        """Run ``x`` through every block of ``self.convlstm`` in order.

        ``x`` is noted in the original as B_T_C_H_W (TODO confirm against
        ``convlstm_layer``). Each block's output is the next block's input;
        an empty stack returns ``x`` unchanged.
        """
        for block in self.convlstm:
            x = block(x)
        return x

    def configure_optimizers(self):
        """Build an AdamW optimiser and its ``PolynomialDecayLR`` schedule.

        Returns:
            ``(optimizer, lr_scheduler)``.
        """
        adamw = torch.optim.AdamW(self.parameters(), lr=self.peak_lr,
                                  weight_decay=self.weight_decay)
        schedule = PolynomialDecayLR(
            adamw,
            warmup_updates=self.warmup_updates,
            tot_updates=self.tot_updates,
            lr=self.peak_lr,
            end_lr=self.end_lr,
            power=self.power,
        )
        return adamw, schedule



class PolynomialDecayLR(_LRScheduler):
    """Linear warm-up followed by polynomial decay of the learning rate.

    The schedule is driven by the scheduler's own ``_step_count`` (the PyTorch
    base class performs one initial step during construction and increments
    the counter on every ``step()``). With ``step = self._step_count``:

    * ``step <= warmup_updates``: ``lr * step / warmup_updates``
    * ``step >= tot_updates``: ``end_lr``
    * otherwise: ``(lr - end_lr) * (1 - (step - warmup_updates) /
      (tot_updates - warmup_updates)) ** power + end_lr``

    The same value is written to every parameter group, regardless of each
    group's own initial ``lr``.

    Args:
        optimizer: wrapped optimizer.
        warmup_updates: number of warm-up steps.
        tot_updates: step at which the rate reaches ``end_lr`` and stays there.
        lr: peak learning rate reached at the end of warm-up.
        end_lr: final learning rate.
        power: exponent of the polynomial decay.
        last_epoch: forwarded to the base scheduler.
        verbose: forwarded to the base scheduler only when truthy. Recent
            PyTorch releases deprecate the argument and warn on any explicit
            value, so the default path omits it.
    """

    def __init__(self, optimizer, warmup_updates, tot_updates, lr, end_lr,
                 power, last_epoch=-1, verbose=False):
        # These must be set before the base __init__, which already calls get_lr().
        self.warmup_updates = warmup_updates
        self.tot_updates = tot_updates
        self.lr = lr
        self.end_lr = end_lr
        self.power = power
        if verbose:
            super(PolynomialDecayLR, self).__init__(optimizer, last_epoch, verbose)
        else:
            super(PolynomialDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, once per param group."""
        if self._step_count <= self.warmup_updates:
            # Linear ramp; step counting starts at 1, so the first rate is lr / warmup.
            self.warmup_factor = self._step_count / float(self.warmup_updates)
            lr = self.warmup_factor * self.lr
        elif self._step_count >= self.tot_updates:
            lr = self.end_lr
        else:
            # Here warmup_updates < step < tot_updates, so the denominator is positive.
            warmup = self.warmup_updates
            lr_range = self.lr - self.end_lr
            pct_remaining = 1 - (self._step_count - warmup) / (self.tot_updates - warmup)
            lr = lr_range * pct_remaining ** (self.power) + self.end_lr

        return [lr] * len(self.optimizer.param_groups)

    def _get_closed_form_lr(self):
        # The base class calls this for step(epoch). That path is unsupported;
        # raise explicitly rather than via an assert, which `python -O` removes.
        raise NotImplementedError(
            "PolynomialDecayLR does not support step(epoch); call step() without arguments."
        )


class MaskedMSELoss(torch.nn.Module):
    """Sum over channels of the masked mean squared error.

    For every channel ``c`` (axis 2), the squared error is multiplied by the
    mask, summed, and divided by the mask's sum for that channel. The
    per-channel values are then added together. A channel whose mask sums to
    zero contributes nothing, instead of producing ``0/0 = NaN``. When the
    whole mask is zero the result is therefore 0, as before.
    """

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, inputs, target, mask):
        '''
        inputs, target, mask: same shape, B_T_C_H_W (channels on axis 2).

        Returns:
            float32 tensor of shape (1,) on ``inputs.device``.
        '''
        result = torch.zeros(1, dtype=torch.float32, device=inputs.device)
        for ch in range(inputs.shape[2]):
            ch_mask = mask[:, :, ch, :, :]
            mask_total = torch.sum(ch_mask)
            if mask_total == 0:
                # No observed points in this channel: skip rather than divide by zero.
                continue
            diff2 = (torch.flatten(inputs[:, :, ch, :, :]) - torch.flatten(target[:, :, ch, :, :])) ** 2.0 * torch.flatten(ch_mask)
            result = result + torch.sum(diff2) / mask_total
        return result


class weightMSELoss(torch.nn.Module):
    """Mean squared error between weighted spatial sums and per-channel targets.

    ``inputs * weight`` is summed over its last axis and then over the new
    last axis (W, then H). The result is compared with ``target`` after its
    two trailing singleton axes are squeezed away.
    Shapes (per the original notes): inputs and weight are B_T_C_H_W; target
    is B_T_F_1_1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, weight):
        collapsed = (inputs * weight).sum(dim=-1)
        collapsed = collapsed.sum(dim=-1)
        reference = target.squeeze(-1).squeeze(-1)
        return F.mse_loss(collapsed, reference)


class MaskedL1Loss(torch.nn.Module):
    """Sum over channels of the masked mean absolute error.

    For every channel ``c`` (axis 2), the absolute error is multiplied by the
    mask, summed, and divided by the mask's sum for that channel. The
    per-channel values are then added together. A channel whose mask sums to
    zero contributes nothing, instead of producing ``0/0 = NaN``.
    """

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, inputs, target, mask):
        '''
        inputs, target, mask: same shape, B_T_C_H_W (channels on axis 2).

        Returns:
            float32 tensor of shape (1,) on ``inputs.device``.
        '''
        result = torch.zeros(1, dtype=torch.float32, device=inputs.device)
        # Loop over the channel axis (2), the axis sliced below. The loop
        # previously used shape[-1] (W), which broke whenever W != C.
        for ch in range(inputs.shape[2]):
            ch_mask = mask[:, :, ch, :, :]
            mask_total = torch.sum(ch_mask)
            if mask_total == 0:
                # No observed points in this channel: skip rather than divide by zero.
                continue
            abs_err = torch.abs(torch.flatten(inputs[:, :, ch, :, :]) - torch.flatten(target[:, :, ch, :, :])) * torch.flatten(ch_mask)
            result = result + torch.sum(abs_err) / mask_total
        return result


class weightL1Loss(torch.nn.Module):
    """Mean absolute error between weighted spatial sums and per-channel targets.

    ``inputs * weight`` is summed over its last axis and then over the new
    last axis (W, then H). The result is compared with ``target`` after its
    two trailing singleton axes are squeezed away.
    Shapes (per the original notes): inputs and weight are B_T_C_H_W; target
    is B_T_F_1_1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, weight):
        collapsed = (inputs * weight).sum(dim=-1)
        collapsed = collapsed.sum(-1)
        reference = target.squeeze(-1).squeeze(-1)
        return F.l1_loss(collapsed, reference)
    
    